import pandas as pd
import numpy as np
from sklearn import datasets
import seaborn as sns
import matplotlib as plt

# Print summary statistics (count, mean, standard deviation, min/max,
# quartiles) and the pairwise feature correlations of the Iris dataset,
# then display the correlation matrix as an annotated heatmap.
#
# NOTE: the top-of-file ``import matplotlib as plt`` binds the matplotlib
# package itself, which has no callable ``figure`` and no ``imshow``/``show``.
# The plotting API lives in ``matplotlib.pyplot``, so rebind ``plt`` here.
import matplotlib.pyplot as plt

iris = datasets.load_iris()
X_train = pd.DataFrame(iris.data, columns=iris.feature_names)
print(X_train.describe())

# Compute the correlation matrix once and reuse it for printing and plotting.
corr = X_train.corr()
print(corr)

# Create the named figure BEFORE drawing, so the heatmap is rendered into it
# rather than into an implicit default figure.
plt.figure("cat")
# sns.heatmap draws onto the current Axes and returns that Axes object (not
# image data), so it must not be passed to plt.imshow; showing is enough.
heatmap = sns.heatmap(corr, annot=True, cmap="RdYlBu")
plt.show()